{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Reusing dataset open_book_qa_dataset (/Users/robertpraas/.cache/huggingface/datasets/open_book_qa_dataset/thoughtsource/1.0.0/b1163057b19c0b0be4f0f1104103cb49d35de01177baec5250adb5246116272f)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading open_book_qa...\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "231160c021ef41ff9caa96f27104d61a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/3 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "02c6dd7c8a5f4a3db8da67d81eb01d98",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4957 [00:00<?, ?ex/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "b0bd0805b25042578d3b99a30fbccce4",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/500 [00:00<?, ?ex/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "183b07e56ecb45ca812a7c60b76349f3",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/500 [00:00<?, ?ex/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "| Name         | Train   | Valid   |   Test |\n",
       "|--------------|---------|---------|--------|\n",
       "| open_book_qa | -       | -       |      1 |\n",
       "\n",
       "Not loaded: ['aqua', 'asdiv', 'commonsense_qa', 'entailment_bank', 'gsm8k', 'mawps', 'med_qa', 'med_qa_open', 'medmc_qa', '_init_', 'mmlu_clinical_knowledge', 'mmlu_college_biology', 'mmlu_college_medicine', 'mmlu_medical_genetics', 'mmlu_professional_medicine', '_init_', 'mmlu_anatomy', 'pubmed_qa', 'qed', 'strategy_qa', 'svamp', 'worldtree']"
      ]
     },
     "execution_count": 39,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Load OpenBookQA (pregenerated CoTs disabled) and keep one sample from the test split.\n",
    "from cot import Collection\n",
    "\n",
    "coll = Collection(\n",
    "    \"open_book_qa\",\n",
    "    load_pregenerated_cots=False,\n",
    ").select(\n",
    "    split=\"test\",\n",
    "    number_samples=1,\n",
    ")\n",
    "coll  # last expression: display the collection summary"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load a local Hugging Face model and wrap it so it can be used as a LangChain LLM.\n",
    "\n",
    "from langchain.llms import HuggingFacePipeline\n",
    "from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, AutoModelForSeq2SeqLM\n",
    "\n",
    "# Single source of truth for the checkpoint so tokenizer and model cannot drift apart.\n",
    "MODEL_CHECKPOINT = \"google/flan-t5-small\"\n",
    "# Value handed to the pipeline as `max_length`.\n",
    "MAX_LENGTH = 100\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(MODEL_CHECKPOINT)\n",
    "model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_CHECKPOINT)\n",
    "\n",
    "# The model is loaded as a seq2seq LM, so the pipeline uses the text2text-generation task.\n",
    "pipe = pipeline(\n",
    "    \"text2text-generation\",\n",
    "    model=model,\n",
    "    tokenizer=tokenizer,\n",
    "    max_length=MAX_LENGTH,\n",
    ")\n",
    "\n",
    "# Wrap the pipeline in LangChain's HuggingFacePipeline; later cells use it as `local_llm`.\n",
    "local_llm = HuggingFacePipeline(pipeline=pipe)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"Build a sequential LangChain pipeline made of two chains: one for reasoning generation and\n",
    "one for answer extraction. Inspired by https://js.langchain.com/docs/modules/chains/sequential_chain\"\"\"\n",
    "from langchain.prompts import PromptTemplate\n",
    "from langchain.chains.llm import LLMChain\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reasoning chain: fills the prompt below and stores the model output under \"cot\".\n",
    "\n",
    "llm = local_llm\n",
    "\n",
    "# Prompt layout: instruction, blank line, question, answer choices, blank line, CoT trigger.\n",
    "template = (\n",
    "    \"{instruction}\\n\"\n",
    "    \"\\n\"\n",
    "    \"Question: {question}\\n\"\n",
    "    \"Answer_choices: {answer_choices}\\n\"\n",
    "    \"\\n\"\n",
    "    \"{cot_trigger}\\n\"\n",
    ")\n",
    "\n",
    "prompt_template = PromptTemplate(\n",
    "    template=template,\n",
    "    input_variables=[\"instruction\", \"question\", \"answer_choices\", \"cot_trigger\"],\n",
    ")\n",
    "cot_chain = LLMChain(prompt=prompt_template, llm=llm, output_key=\"cot\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Answer-extraction chain: repeats the question together with the generated CoT and stores\n",
    "# the model output under \"predicted_answer\".\n",
    "\n",
    "extraction_template = (\n",
    "    \"{instruction}\\n\"\n",
    "    \"\\n\"\n",
    "    \"Question: {question}\\n\"\n",
    "    \"Answer_choices: {answer_choices}\\n\"\n",
    "    \"\\n\"\n",
    "    \"Cot: {cot_trigger}{cot}\\n\"\n",
    "    \"{answer_extraction}\\n\"\n",
    ")\n",
    "\n",
    "# NOTE(review): `prompt_template` is also assigned in the reasoning-chain cell; after this\n",
    "# cell the name refers to the extraction prompt. `llm` is defined in that same earlier cell.\n",
    "prompt_template = PromptTemplate(\n",
    "    template=extraction_template,\n",
    "    input_variables=[\n",
    "        \"instruction\",\n",
    "        \"question\",\n",
    "        \"answer_choices\",\n",
    "        \"cot_trigger\",\n",
    "        \"cot\",\n",
    "        \"answer_extraction\",\n",
    "    ],\n",
    ")\n",
    "answer_chain = LLMChain(prompt=prompt_template, llm=llm, output_key=\"predicted_answer\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Combine both chains: the reasoning chain is listed first, the answer-extraction chain second.\n",
    "\n",
    "from langchain.chains import SequentialChain\n",
    "\n",
    "overall_chain = SequentialChain(\n",
    "    chains=[cot_chain, answer_chain],\n",
    "    input_variables=[\n",
    "        \"instruction\",\n",
    "        \"question\",\n",
    "        \"answer_choices\",\n",
    "        \"cot_trigger\",\n",
    "        \"answer_extraction\",\n",
    "    ],\n",
    "    output_variables=[\"cot\", \"predicted_answer\"],\n",
    "    verbose=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Generating open_book_qa...\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "1f6ea29809e64abdb6f6d1c2ff53e433",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/1 [00:00<?, ?ex/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "\u001b[1m> Entering new SequentialChain chain...\u001b[0m\n",
      "\n",
      "\u001b[1m> Finished chain.\u001b[0m\n",
      "{'id': 'open_book_qa_test_432', 'ref_id': '', 'question': 'How many times would someone change the page of a calendar in a year?', 'type': 'multiplechoice', 'choices': ['13', '12', '15', '14'], 'context': '', 'cot': ['one year is equal to 365 days'], 'answer': ['12'], 'generated_cot': [{'id': 'b2e8943b-f4a7-4d6c-a9d6-68d132d6f38b', 'fragments_version': None, 'instruction': 'Be faithful and a little hopeful', 'cot_trigger': \"Answer: Let's think step by step.\", 'cot_trigger_template': '', 'prompt_text': '', 'cot': 'C', 'answers': [{'id': 'd9ebc5e8-cc5d-41af-9ce6-8411a3744604', 'answer_extraction': 'Therefore, among A through D, the answer is', 'answer_extraction_template': '', 'answer_extraction_text': '', 'answer': 'A', 'answer_from_choices': '', 'correct_answer': None}], 'author': '', 'date': '2023/05/12 11:59:34', 'api_service': '', 'model': \"{'name': 'flan-T5-small', 'temperature': 0, 'max_tokens': 800}\", 'comment': 'generated and extracted', 'annotations': []}], 'feedback': []}\n"
     ]
    }
   ],
   "source": [
    "# Build the input dictionary that is fed to the chain.\n",
    "input_dict = {\n",
    "    \"input_dict\": {\n",
    "        \"chain\": overall_chain,\n",
    "        \"instruction\": \"Be faithful and a little hopeful\",\n",
    "        \"cot_trigger\": \"Answer: Let's think step by step.\",\n",
    "        \"answer_extraction\": \"Therefore, among A through D, the answer is\",\n",
    "        \"model\": \"flan-T5-small\",\n",
    "        \"temperature\": 0,\n",
    "        # NOTE(review): the pipeline in the model-loading cell is built with max_length=100,\n",
    "        # which does not match this value -- confirm which limit should be recorded here.\n",
    "        \"max_tokens\": 800,\n",
    "    }\n",
    "}\n",
    "\n",
    "# generate_extract_flexible is called in the dataloader.py file where the collection resides.\n",
    "# From there a function in the generate.py file is loaded. Besides this function you could use\n",
    "# the functions {generate/extract/metareason}_flexible.\n",
    "coll.generate_extract_flexible(input_dict=input_dict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "86206553eaec42b7a50025369e8f20c1",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/1 [00:00<?, ?ex/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "{'open_book_qa': {'test': {'accuracy': {'flan-T5-small': {'custom_model_title': 0.0}}}}}"
      ]
     },
     "execution_count": 46,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Evaluate the generated answer; the result is reported under the title passed here.\n",
    "coll.evaluate(title=\"custom_model_title\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
